from typing import Dict, List
import numpy as np
from .losses import (
    LossWrapper,
    GradientBalancer,
    relative_norm_mse,
)
from .integrals import FluxIntegral
from .train import train_step_autoencoder, train_step_peft
from .peft_utils import (
    create_lora_model,
    get_target_modules_for_lora,
    save_lora_weights,
    save_peft_weights,
    load_lora_weights,
    setup_peft_stage,
    freeze_base_parameters,
)


def _scalar_moment(value) -> float:
    """Return a scalar or single-element statistic as a Python float.

    Accepts Python scalars, NumPy scalars and arrays of any rank that hold
    exactly one element. Raises ``ValueError`` (from ``ndarray.item``) when
    ``value`` holds more than one element.
    """
    return float(np.asarray(value).item())


def _combine_phi_moments(phi_mean, phi_std):
    """Reduce a file's phi mean/std to a single ``(mean, variance)`` pair.

    When the leading axis of ``phi_mean`` has length 2 (the original code
    notes this is the separate real/imaginary channel layout, e.g. shape
    ``(2, 1, 1, 1)``), the channels are merged: the mean becomes the RMS of
    the per-channel means and the variance becomes the average of the
    per-channel variances. Otherwise both statistics must hold exactly one
    element and are returned as plain floats.
    """
    mean_arr = np.asarray(phi_mean)
    var_arr = np.asarray(phi_std) ** 2

    # Check ndim first: scalar / 0-d statistics have shape == (), so indexing
    # shape[0] directly would raise IndexError and the file would be dropped.
    if mean_arr.ndim > 0 and mean_arr.shape[0] == 2:
        combined_mean = float(np.sqrt(np.mean(mean_arr**2)))
        combined_var = float(np.mean(var_arr))
        return combined_mean, combined_var

    return _scalar_moment(mean_arr), _scalar_moment(var_arr)


def _count_file_samples(h5_file, metadata) -> int:
    """Return the number of samples stored in one open dataset file.

    Counts the ``timestep_*`` entries of the ``data`` group when that group
    exists; otherwise falls back to the length of ``metadata["timesteps"]``.
    """
    if "data" in h5_file:
        return sum(1 for key in h5_file["data"] if key.startswith("timestep_"))
    return len(metadata["timesteps"][()])


def _push_file_moments(stats, mean: float, var: float, n_samples: int) -> None:
    """Feed one file's (mean, variance) into ``stats``, weighted by sample count.

    Per-file extrema are not available from the metadata read here, so the
    mean is passed for both ``batch_min`` and ``batch_max``.
    """
    stats.update_from_moments(
        batch_mean=np.array([mean]),
        batch_var=np.array([var]),
        batch_min=np.array([mean]),  # Using mean as min/max
        batch_max=np.array([mean]),
        batch_count=float(n_samples),
    )


def aggregate_dataset_stats(file_paths: List[str]) -> Dict[str, float]:
    """
    Aggregate per-file statistics into dataset-wide statistics.

    Each file is opened read-only with h5py. Files without a ``metadata``
    group, files with no samples, and files that raise while being read are
    skipped (the latter with a printed warning). For every remaining file the
    stored ``phi_mean``/``phi_std`` and ``flux_mean``/``flux_std`` values are
    merged into running statistics, weighted by the file's sample count.

    Args:
        file_paths: Paths of the dataset files to aggregate.

    Returns:
        A dict containing ``phi_mean`` and ``phi_std`` if at least one file
        contributed phi statistics, and ``flux_mean`` and ``flux_std`` if at
        least one file contributed flux statistics. The dict is empty when
        nothing could be aggregated.
    """
    import h5py
    from utils import RunningMeanStd

    # One single-element accumulator per quantity.
    phi_stats = RunningMeanStd((1,))
    flux_stats = RunningMeanStd((1,))

    for file_path in file_paths:
        try:
            with h5py.File(file_path, "r") as f:
                if "metadata" not in f:
                    continue

                metadata = f["metadata"]
                n_samples = _count_file_samples(f, metadata)

                # A zero-weight update cannot change the weighted moments, so
                # skip it rather than rely on how RunningMeanStd handles a
                # zero batch_count (assumption: it gains nothing from one).
                if n_samples <= 0:
                    continue

                if "phi_mean" in metadata and "phi_std" in metadata:
                    phi_mean, phi_var = _combine_phi_moments(
                        metadata["phi_mean"][()], metadata["phi_std"][()]
                    )
                    _push_file_moments(phi_stats, phi_mean, phi_var, n_samples)

                if "flux_mean" in metadata and "flux_std" in metadata:
                    flux_mean = _scalar_moment(metadata["flux_mean"][()])
                    # Square before converting, in the stored dtype.
                    flux_var = _scalar_moment(
                        np.asarray(metadata["flux_std"][()]) ** 2
                    )
                    _push_file_moments(flux_stats, flux_mean, flux_var, n_samples)

        except Exception as e:
            # Deliberately best-effort: one unreadable or malformed file must
            # not abort aggregation over the remaining files.
            print(f"Warning: Could not process {file_path}: {e}")
            continue

    # Extract final aggregated statistics
    aggregated_stats = {}
    if phi_stats.count > 0:
        aggregated_stats["phi_mean"] = float(phi_stats.mean.item())
        aggregated_stats["phi_std"] = float(np.sqrt(phi_stats.var).item())

    if flux_stats.count > 0:
        aggregated_stats["flux_mean"] = float(flux_stats.mean.item())
        aggregated_stats["flux_std"] = float(np.sqrt(flux_stats.var).item())

    return aggregated_stats


# Names exported by `from <package> import *`: everything this module imports
# from its submodules for re-export, plus aggregate_dataset_stats defined here.
__all__ = [
    "LossWrapper",
    "GradientBalancer",
    "relative_norm_mse",
    "FluxIntegral",
    "train_step_autoencoder",
    "train_step_peft",
    "create_lora_model",
    "get_target_modules_for_lora",
    "save_lora_weights",
    "save_peft_weights",
    "load_lora_weights",
    "setup_peft_stage",
    "freeze_base_parameters",
    "aggregate_dataset_stats",
]
